use self::spirv_grammar::SpirvGrammar;
use foldhash::HashMap;
use nom::{
    bytes::complete::{tag, take_until},
    character::complete::{self, multispace0, multispace1},
    combinator::eof,
    sequence::{delimited, tuple},
    IResult, Parser,
};
use proc_macro2::TokenStream;
use std::{
    cmp::min,
    fs::File,
    io::{BufWriter, Write},
    ops::BitOrAssign,
    path::Path,
};
use vk_parse::{
    Command, Enum, EnumSpec, Enums, EnumsChild, Extension, ExtensionChild, Feature, Format,
    InterfaceItem, Registry, RegistryChild, SpirvExtOrCap, Type, TypeSpec, TypesChild,
};

mod conjunctive_normal_form;
mod errors;
mod extensions;
mod features;
mod fns;
mod formats;
mod properties;
mod spirv_grammar;
mod spirv_parse;
mod spirv_reqs;
mod version;

/// Directory holding the input specifications read by `main` (`vk.xml` and
/// `spirv.core.grammar.json`).
const INPUT_DIR: &str = concat!(std::env!("CARGO_MANIFEST_DIR"), "/../autogen");
/// Directory that `write_file` writes the generated Rust sources into.
const OUTPUT_DIR: &str = concat!(std::env!("CARGO_MANIFEST_DIR"), "/../vulkano/autogen-out/");

/// `indexmap::IndexMap` using foldhash's `fast::RandomState` hasher; used where the order of
/// entries matters (e.g. the sorted extension list).
pub(crate) type IndexMap<K, V> = indexmap::IndexMap<K, V, foldhash::fast::RandomState>;

/// Loads `vk.xml` and the SPIR-V core grammar from [`INPUT_DIR`], then runs every generator
/// module on the extracted data.
fn main() {
    let autogen_dir = Path::new(INPUT_DIR);
    let vk_xml_path = autogen_dir.join("vk.xml");
    let spirv_grammar_path = autogen_dir.join("spirv.core.grammar.json");

    let registry = get_vk_registry(&vk_xml_path);
    let vk_data = VkRegistryData::new(&registry);
    let spirv_grammar = get_spirv_grammar(&spirv_grammar_path);

    errors::write(&vk_data);
    extensions::write(&vk_data);
    features::write(&vk_data);
    formats::write(&vk_data);
    fns::write(&vk_data);
    properties::write(&vk_data);
    spirv_parse::write(&spirv_grammar);
    spirv_reqs::write(&vk_data, &spirv_grammar);
    version::write(&vk_data);
}

/// Pretty-prints `contents` and writes it to `file` (relative to [`OUTPUT_DIR`]), preceded by
/// a header comment stating that the file was generated from `source`.
///
/// # Panics
///
/// Panics if `contents` does not parse as a Rust source file, or if the output file cannot be
/// created, written or flushed.
fn write_file(file: impl AsRef<Path>, source: impl AsRef<str>, contents: TokenStream) {
    let contents = prettyplease::unparse(&syn::parse2(contents).unwrap());

    let path = Path::new(OUTPUT_DIR).join(file.as_ref());
    let output = File::create(&path)
        .unwrap_or_else(|err| panic!("failed to create `{}`: {err}", path.display()));
    let mut writer = BufWriter::new(output);

    write!(
        writer,
        "\
        // This file is auto-generated by vulkano autogen from {}.\n\
        // It should not be edited manually. Changes should be made by editing autogen.\n\
        \n\n{}",
        source.as_ref(),
        contents,
    )
    .unwrap_or_else(|err| panic!("failed to write `{}`: {err}", path.display()));

    // `BufWriter` flushes on drop but discards any error from that final flush; flush
    // explicitly so a truncated output file cannot go unnoticed.
    writer
        .flush()
        .unwrap_or_else(|err| panic!("failed to write `{}`: {err}", path.display()));
}

/// Parses the Vulkan registry XML file at `path`.
///
/// Problems that `vk_parse` reports alongside a successfully parsed registry are printed to
/// stderr and do not stop generation.
///
/// # Panics
///
/// Panics if `vk_parse::parse_file` returns an error.
fn get_vk_registry<P: AsRef<Path> + ?Sized>(path: &P) -> Registry {
    let (registry, problems) = vk_parse::parse_file(path.as_ref()).unwrap();

    if !problems.is_empty() {
        eprintln!("The following errors were found while parsing the file:");
        problems
            .iter()
            .for_each(|problem| eprintln!("{:?}", problem));
    }

    registry
}

/// Data extracted from the Vulkan registry (`vk.xml`) for use by the generator modules.
///
/// Everything borrows from the parsed [`Registry`] for the lifetime `'r`.
pub struct VkRegistryData<'r> {
    /// `(major, minor, patch)`: major/minor from `VK_HEADER_VERSION_COMPLETE`, patch from
    /// `VK_HEADER_VERSION`.
    pub header_version: (u16, u16, u16),
    /// Names of `VkResult` values treated as errors: negative values, `VK_NOT_READY`,
    /// `VK_TIMEOUT`, and `VkResult` values that features/extensions add via an offset with
    /// `dir: false`.
    pub errors: Vec<&'r str>,
    /// Extensions whose `supported` list includes `vulkan` and that are not obsoleted, keyed by
    /// name and ordered `VK_KHR_*`, then `VK_EXT_*`, then the rest (by name within each group).
    pub extensions: IndexMap<&'r str, &'r Extension>,
    /// `<feature>` elements whose comma-separated `api` list includes `vulkan`, keyed by name,
    /// in registry order.
    pub features: IndexMap<&'r str, &'r Feature>,
    /// Every entry of every `<formats>` section, in registry order.
    pub formats: Vec<&'r Format>,
    /// Every SPIR-V capability entry in the registry.
    pub spirv_capabilities: Vec<&'r SpirvExtOrCap>,
    /// Every SPIR-V extension entry in the registry.
    pub spirv_extensions: Vec<&'r SpirvExtOrCap>,
    /// Non-alias types keyed by name, each paired with the names of the features/extensions
    /// that require it (directly or through an alias). Types required by none are omitted.
    pub types: HashMap<&'r str, (&'r Type, Vec<&'r str>)>,
    /// Command definitions and aliases keyed by command name, in registry order.
    pub commands: IndexMap<&'r str, &'r Command>,
}

impl<'r> VkRegistryData<'r> {
    /// Extracts everything the generator modules need from a parsed `vk.xml` registry.
    ///
    /// # Panics
    ///
    /// Panics if the registry does not contain exactly one parseable `VK_HEADER_VERSION` and
    /// exactly one parseable `VK_HEADER_VERSION_COMPLETE` definition for the `vulkan` API, if an
    /// alias type has no name, or if it contains a `Command` that is neither a definition nor
    /// an alias.
    fn new(registry: &'r Registry) -> Self {
        let aliases = Self::get_aliases(registry);
        let extensions = Self::get_extensions(registry);
        let features = Self::get_features(registry);
        let formats = Self::get_formats(registry);
        let spirv_capabilities = Self::get_spirv_capabilities(registry);
        let spirv_extensions = Self::get_spirv_extensions(registry);
        // The feature/extension-derived parts of `errors` and `types` only come from the
        // already-filtered `vulkan` features and extensions above.
        let errors = Self::get_errors(registry, &features, &extensions);
        let types = Self::get_types(registry, &aliases, &features, &extensions);
        let header_version = Self::get_header_version(registry);
        let commands = Self::get_commands(registry);

        VkRegistryData {
            header_version,
            errors,
            extensions,
            features,
            formats,
            spirv_capabilities,
            spirv_extensions,
            types,
            commands,
        }
    }

    /// Returns the Vulkan header version `(major, minor, patch)` declared in the vk.xml file.
    ///
    /// `patch` is read from `#define VK_HEADER_VERSION <n>`, and `major`/`minor` from
    /// `#define VK_HEADER_VERSION_COMPLETE VK_MAKE_API_VERSION(<a>, <major>, <minor>,
    /// VK_HEADER_VERSION)` (the first argument is ignored). Only `<type>` entries whose `api`
    /// attribute is exactly `vulkan` are examined.
    ///
    /// # Panics
    ///
    /// Panics if either definition is missing or is found more than once.
    fn get_header_version(registry: &Registry) -> (u16, u16, u16) {
        /// Parses a `,` surrounded by optional whitespace.
        fn spaced_comma(input: &str) -> IResult<&str, char> {
            delimited(multispace0, complete::char(','), multispace0)(input)
        }

        /// Parses `... #define VK_HEADER_VERSION <u16>`, followed only by whitespace.
        fn vk_header_patch(input: &str) -> IResult<&str, u16> {
            let (input, _) = take_until("#define")(input)?;
            delimited(
                tuple((
                    tag("#define"),
                    multispace1,
                    tag("VK_HEADER_VERSION"),
                    multispace0,
                )),
                complete::u16,
                tuple((multispace0, eof)),
            )(input)
        }

        /// Parses `... #define VK_HEADER_VERSION_COMPLETE VK_MAKE_API_VERSION(<a>, <major>,
        /// <minor>, VK_HEADER_VERSION)` and returns `(major, minor)`.
        fn vk_header_major_minor(input: &str) -> IResult<&str, (u16, u16)> {
            let (input, _) = take_until("#define")(input)?;
            delimited(
                tuple((
                    tag("#define"),
                    multispace1,
                    tag("VK_HEADER_VERSION_COMPLETE"),
                    multispace1,
                    tag("VK_MAKE_API_VERSION"),
                    multispace0,
                    complete::char('('),
                    multispace0,
                )),
                tuple((
                    complete::u16,
                    spaced_comma,
                    complete::u16,
                    spaced_comma,
                    complete::u16,
                    spaced_comma,
                    tag("VK_HEADER_VERSION"),
                ))
                .map(|(_ignored, _, major, _, minor, _, _)| (major, minor)),
                tuple((multispace0, complete::char(')'), multispace0)),
            )(input)
        }

        let mut major = None;
        let mut minor = None;
        let mut patch = None;

        for child in registry.0.iter() {
            if let RegistryChild::Types(types) = child {
                for ty in types.children.iter() {
                    if let TypesChild::Type(ty) = ty {
                        if ty.api.as_deref() != Some("vulkan") {
                            continue;
                        }
                        if let TypeSpec::Code(code) = &ty.spec {
                            // `vk_header_patch` is tried first. On the `_COMPLETE` definition
                            // it fails, because `VK_HEADER_VERSION` is followed by `_COMPLETE`
                            // rather than a number, so that line falls through to the second
                            // parser.
                            if let Ok((_, p)) = vk_header_patch(&code.code) {
                                assert!(
                                    patch.is_none(),
                                    "duplicate `VK_HEADER_VERSION` definition in vk.xml",
                                );
                                patch = Some(p);
                            } else if let Ok((_, (m, n))) = vk_header_major_minor(&code.code) {
                                assert!(
                                    major.is_none(),
                                    "duplicate `VK_HEADER_VERSION_COMPLETE` definition in vk.xml",
                                );
                                major = Some(m);
                                minor = Some(n);
                            }
                        }
                    }
                }
            }
        }

        (
            major.expect("`VK_HEADER_VERSION_COMPLETE` definition not found in vk.xml"),
            minor.expect("`VK_HEADER_VERSION_COMPLETE` definition not found in vk.xml"),
            patch.expect("`VK_HEADER_VERSION` definition not found in vk.xml"),
        )
    }

    /// Maps the name of every type declared with an `alias` attribute to the name it aliases.
    ///
    /// # Panics
    ///
    /// Panics if an alias type has no name.
    fn get_aliases(registry: &Registry) -> HashMap<&str, &str> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::Types(types) = child {
                    return Some(types.children.iter().filter_map(|ty| {
                        if let TypesChild::Type(ty) = ty {
                            if let Some(alias) = ty.alias.as_deref() {
                                return Some((ty.name.as_ref().unwrap().as_str(), alias));
                            }
                        }
                        None
                    }));
                }
                None
            })
            .flatten()
            .collect()
    }

    /// Follows `aliases` from `name` until reaching a name that is not itself an alias.
    ///
    /// The walk is bounded by the number of aliases, so a malformed (cyclic) alias table cannot
    /// loop forever; in that case some name on the cycle is returned.
    fn resolve_alias<'a>(aliases: &HashMap<&'a str, &'a str>, mut name: &'a str) -> &'a str {
        for _ in 0..=aliases.len() {
            match aliases.get(name) {
                Some(&target) => name = target,
                None => break,
            }
        }
        name
    }

    /// Returns the names of the `VkResult` values that are treated as errors.
    ///
    /// First come the values of the registry's `VkResult` enum that are negative, plus
    /// `VK_NOT_READY` and `VK_TIMEOUT`. Then come the `VkResult` values added by the given
    /// features and extensions through an offset with `dir: false` (presumably the negative,
    /// i.e. error, direction; confirm against `vk_parse`). No de-duplication is performed.
    fn get_errors<'a>(
        registry: &'a Registry,
        features: &IndexMap<&'a str, &'a Feature>,
        extensions: &IndexMap<&'a str, &'a Extension>,
    ) -> Vec<&'a str> {
        registry
            .0
            .iter()
            .filter_map(|child| match child {
                RegistryChild::Enums(Enums {
                    name: Some(name),
                    children,
                    ..
                }) if name == "VkResult" => Some(children.iter().filter_map(|en| {
                    if let EnumsChild::Enum(en) = en {
                        if let EnumSpec::Value { value, .. } = &en.spec {
                            // Treat NotReady and Timeout as error conditions
                            if value.starts_with('-')
                                || matches!(en.name.as_str(), "VK_NOT_READY" | "VK_TIMEOUT")
                            {
                                return Some(en.name.as_str());
                            }
                        }
                    }
                    None
                })),
                _ => None,
            })
            .flatten()
            .chain(
                features
                    .values()
                    .map(|feature| feature.children.iter())
                    .chain(
                        extensions
                            .values()
                            .map(|extension| extension.children.iter()),
                    )
                    .flatten()
                    .filter_map(|child| {
                        if let ExtensionChild::Require { items, .. } = child {
                            return Some(items.iter().filter_map(|item| match item {
                                InterfaceItem::Enum(Enum {
                                    name,
                                    spec:
                                        EnumSpec::Offset {
                                            extends,
                                            dir: false,
                                            ..
                                        },
                                    ..
                                }) if extends == "VkResult" => Some(name.as_str()),
                                _ => None,
                            }));
                        }
                        None
                    })
                    .flatten(),
            )
            .collect()
    }

    /// Returns the extensions whose comma-separated `supported` list contains `vulkan` and that
    /// have no `obsoletedby`, keyed by name.
    ///
    /// The map is ordered with `VK_KHR_*` extensions first, then `VK_EXT_*`, then all others,
    /// each group sorted by name.
    fn get_extensions(registry: &Registry) -> IndexMap<&str, &Extension> {
        /// Sort key: extension group first (KHR, EXT, other), then the name itself.
        fn sort_key(name: &str) -> (u8, &str) {
            let group = if name.starts_with("VK_KHR_") {
                0
            } else if name.starts_with("VK_EXT_") {
                1
            } else {
                2
            };
            (group, name)
        }

        let mut extensions: Vec<&Extension> = registry
            .0
            .iter()
            .filter_map(|child| match child {
                RegistryChild::Extensions(ext_list) => Some(ext_list.children.iter()),
                _ => None,
            })
            .flatten()
            .filter(|ext| {
                ext.supported
                    .as_deref()
                    .is_some_and(|apis| apis.split(',').any(|api| api == "vulkan"))
                    && ext.obsoletedby.is_none()
            })
            .collect();

        extensions.sort_unstable_by(|a, b| sort_key(&a.name).cmp(&sort_key(&b.name)));

        extensions
            .into_iter()
            .map(|ext| (ext.name.as_str(), ext))
            .collect()
    }

    /// Returns the `<feature>` elements whose comma-separated `api` list contains `vulkan`,
    /// keyed by name, in registry order.
    fn get_features(registry: &Registry) -> IndexMap<&str, &Feature> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::Feature(feat) = child {
                    if feat.api.split(',').any(|s| s == "vulkan") {
                        return Some((feat.name.as_str(), feat));
                    }
                }

                None
            })
            .collect()
    }

    /// Returns every entry of every `<formats>` section, in registry order.
    fn get_formats(registry: &Registry) -> Vec<&Format> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::Formats(formats) = child {
                    return Some(formats.children.iter());
                }
                None
            })
            .flatten()
            .collect()
    }

    /// Returns every SPIR-V capability entry in the registry, in registry order.
    fn get_spirv_capabilities(registry: &Registry) -> Vec<&SpirvExtOrCap> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::SpirvCapabilities(capabilities) = child {
                    return Some(capabilities.children.iter());
                }
                None
            })
            .flatten()
            .collect()
    }

    /// Returns every SPIR-V extension entry in the registry, in registry order.
    fn get_spirv_extensions(registry: &Registry) -> Vec<&SpirvExtOrCap> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::SpirvExtensions(extensions) = child {
                    return Some(extensions.children.iter());
                }
                None
            })
            .flatten()
            .collect()
    }

    /// Returns every named, non-alias type together with the names of the given features and
    /// extensions that require it.
    ///
    /// A requirement that names an alias is attributed to the type that the alias ultimately
    /// refers to, following chains of aliases. Each provider is recorded at most once per type,
    /// with features visited before extensions. Types that no feature or extension requires
    /// are omitted.
    fn get_types<'a>(
        registry: &'a Registry,
        aliases: &HashMap<&'a str, &'a str>,
        features: &IndexMap<&'a str, &'a Feature>,
        extensions: &IndexMap<&'a str, &'a Extension>,
    ) -> HashMap<&'a str, (&'a Type, Vec<&'a str>)> {
        // Start with every named non-alias type and an empty provider list.
        let mut types: HashMap<&str, (&Type, Vec<&str>)> = registry
            .0
            .iter()
            .filter_map(|child| {
                if let RegistryChild::Types(types) = child {
                    return Some(types.children.iter().filter_map(|ty| {
                        if let TypesChild::Type(ty) = ty {
                            if ty.alias.is_none() {
                                return ty.name.as_ref().map(|name| (name.as_str(), (ty, vec![])));
                            }
                        }
                        None
                    }));
                }
                None
            })
            .flatten()
            .collect();

        features
            .iter()
            .map(|(name, feature)| (name, &feature.children))
            .chain(extensions.iter().map(|(name, ext)| (name, &ext.children)))
            .for_each(|(provided_by, children)| {
                children
                    .iter()
                    .filter_map(|child| {
                        if let ExtensionChild::Require { items, .. } = child {
                            return Some(items.iter());
                        }
                        None
                    })
                    .flatten()
                    .filter_map(|item| {
                        if let InterfaceItem::Type { name, .. } = item {
                            return Some(name.as_str());
                        }
                        None
                    })
                    .for_each(|item_name| {
                        // Aliases are not keys of `types`, so resolve the whole alias chain;
                        // stopping after one hop would silently drop an alias-of-alias.
                        let item_name = Self::resolve_alias(aliases, item_name);
                        if let Some(ty) = types.get_mut(item_name) {
                            if !ty.1.contains(provided_by) {
                                ty.1.push(provided_by);
                            }
                        }
                    });
            });

        types
            .into_iter()
            .filter(|(_key, val)| !val.1.is_empty())
            .collect()
    }

    /// Returns all commands keyed by name, in registry order.
    ///
    /// Aliases are keyed by their own name and definitions by their prototype's name.
    ///
    /// # Panics
    ///
    /// Panics on a `Command` variant that is neither an alias nor a definition.
    fn get_commands(registry: &Registry) -> IndexMap<&str, &Command> {
        registry
            .0
            .iter()
            .filter_map(|child| {
                // TODO: resolve aliases into CommandDefinition immediately?
                if let RegistryChild::Commands(commands) = child {
                    return Some(commands.children.iter().map(|c| {
                        (
                            match c {
                                Command::Alias { name, .. } => name.as_str(),
                                Command::Definition(d) => d.proto.name.as_str(),
                                _ => unimplemented!("unsupported `vk_parse::Command` variant"),
                            },
                            c,
                        )
                    }));
                }
                None
            })
            .flatten()
            .collect()
    }
}

/// Loads the SPIR-V core grammar from `path` and normalizes its ordering.
///
/// - Instructions are sorted by opcode. Where several instructions share an opcode, only the
///   one whose name has the "most official" suffix is kept (see `suffix_key`).
/// - Enumerants of `BitEnum` operand kinds are sorted by their hexadecimal bit value.
/// - Enumerants of `ValueEnum` operand kinds are sorted by their numeric value.
///
/// In both enum cases, ties are broken by `suffix_key`. Enumerants are only reordered, not
/// de-duplicated, so aliases sharing a value are all kept.
///
/// # Panics
///
/// Panics if a `BitEnum` enumerant value is not a `0x`-prefixed hexadecimal string that fits
/// in a `u32`.
pub fn get_spirv_grammar<P: AsRef<Path> + ?Sized>(path: &P) -> SpirvGrammar {
    let mut grammar = SpirvGrammar::new(path);

    let instructions = &mut grammar.instructions;
    instructions.sort_by_key(|instruction| (instruction.opcode, suffix_key(&instruction.opname)));
    // After sorting, the preferred name for each opcode comes first; `dedup` keeps the first.
    instructions.dedup_by_key(|instruction| instruction.opcode);

    for operand_kind in grammar.operand_kinds.iter_mut() {
        if operand_kind.category == "BitEnum" {
            operand_kind.enumerants.sort_by_key(|enumerant| {
                let hex_digits = enumerant
                    .value
                    .as_str()
                    .unwrap()
                    .strip_prefix("0x")
                    .unwrap();
                let bits = u32::from_str_radix(hex_digits, 16).unwrap();
                (bits, suffix_key(&enumerant.enumerant))
            });
        } else if operand_kind.category == "ValueEnum" {
            operand_kind.enumerants.sort_by_key(|enumerant| {
                let number = enumerant.value.as_u64();
                (number, suffix_key(&enumerant.enumerant))
            });
        }
    }

    grammar
}

/// Ranks a SPIR-V name by how "official" its suffix is; lower is preferred.
///
/// - `0`: no recognized suffix.
/// - `1`: `KHR`.
/// - `2`: `EXT`.
/// - `3`: vendor suffixes `AMD`, `GOOGLE`, `INTEL`, `NV`.
fn suffix_key(name: &str) -> u32 {
    const VENDOR_SUFFIXES: [&str; 4] = ["AMD", "GOOGLE", "INTEL", "NV"];

    if VENDOR_SUFFIXES.iter().any(|&suffix| name.ends_with(suffix)) {
        return 3;
    }

    match () {
        _ if name.ends_with("EXT") => 2,
        _ if name.ends_with("KHR") => 1,
        _ => 0,
    }
}

/// A collection of alternative requirements, as suggested by the name and by `|=` keeping the
/// lowest API version (NOTE(review): confirm "any one suffices" semantics with the consumers).
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct RequiresOneOf {
    /// Required API version; presumably `(major, minor)`.
    pub api_version: Option<(u32, u32)>,
    /// Names of device extensions.
    pub device_extensions: Vec<String>,
    /// Names of instance extensions.
    pub instance_extensions: Vec<String>,
    /// Names of device features.
    pub device_features: Vec<String>,
}

impl RequiresOneOf {
    /// Returns `true` when no API version, extension or feature is recorded.
    pub fn is_empty(&self) -> bool {
        // Destructure exhaustively so that adding a field is a compile error here until it
        // is taken into account.
        let Self {
            api_version,
            device_extensions,
            instance_extensions,
            device_features,
        } = self;

        let no_names = [device_extensions, instance_extensions, device_features]
            .iter()
            .all(|names| names.is_empty());

        no_names && api_version.is_none()
    }
}

impl BitOrAssign<&Self> for RequiresOneOf {
    /// Merges `rhs` into `self`.
    ///
    /// The lower of the two API versions is kept (or whichever one is set). Each name from
    /// `rhs` that is not already present is appended, preserving order.
    fn bitor_assign(&mut self, rhs: &Self) {
        /// Appends every entry of `src` not already in `dst`, checking against the growing `dst`.
        fn append_missing(dst: &mut Vec<String>, src: &[String]) {
            for name in src {
                if !dst.contains(name) {
                    dst.push(name.clone());
                }
            }
        }

        self.api_version = match (self.api_version, rhs.api_version) {
            (Some(ours), Some(theirs)) => Some(min(ours, theirs)),
            (ours, theirs) => ours.or(theirs),
        };

        append_missing(&mut self.device_extensions, &rhs.device_extensions);
        append_missing(&mut self.instance_extensions, &rhs.instance_extensions);
        append_missing(&mut self.device_features, &rhs.device_features);
    }
}
